# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import functools
import math

import pytest

import tvm.testing
from tvm import relax as rx
from tvm import tir
from tvm.relax.analysis import get_var2val
from tvm.relax.dpl import *
from tvm.script import relax as R
from tvm.script import tir as T


@tvm.script.ir_module
class Module:
    # Shared fixture for the matching tests below: main() chains tir_matmul
    # into tir_relu and also produces a tir_zeros tensor sized by a TIR var.
    @T.prim_func
    def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
        # C[i, j] += A[i, k] * B[j, k] over a 32x32x32 grid (B indexed [j, k]).
        T.func_attr({"global_symbol": "tir_matmul"})
        k = T.int32()
        A = T.match_buffer(x, (32, 32))
        B = T.match_buffer(y, (32, 32))
        C = T.match_buffer(z, (32, 32))

        for i0, j0, k0 in T.grid(32, 32, 32):
            with T.block():
                i, j, k = T.axis.remap("SSR", [i0, j0, k0])
                with T.init():
                    C[i, j] = 0.0
                C[i, j] += A[i, k] * B[j, k]

    @T.prim_func
    def tir_relu(x: T.handle, y: T.handle):
        # Elementwise B = max(A, 0) on a 32x32 buffer.
        T.func_attr({"global_symbol": "tir_relu"})
        A = T.match_buffer(x, (32, 32))
        B = T.match_buffer(y, (32, 32))
        for i, j in T.grid(32, 32):
            with T.block():
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = T.max(A[vi, vj], 0.0)

    @T.prim_func
    def tir_zeros(x: T.handle, n: T.int64):
        # NOTE(review): despite the name, this fills the buffer with 1.0 —
        # the tests only match on the callee name, so this is harmless here,
        # but confirm before reusing the fixture elsewhere.
        T.func_attr({"global_symbol": "tir_zeros"})
        A = T.match_buffer(x, [n])
        for i in range(n):
            with T.block():
                vi = T.axis.remap("S", [i])
                A[vi] = 1.0

    @R.function
    def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tuple:
        # Bindings: lv0 = matmul(x, w); lv1 = relu(lv0); lv2 = zeros via
        # tir_vars=[32]; returns the tuple (lv1, lv2).
        cls = Module
        with R.dataflow():
            lv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32"))
            lv1 = R.call_tir(cls.tir_relu, (lv0), R.Tensor((32, 32), dtype="float32"))
            lv2 = R.call_tir(
                cls.tir_zeros, [], R.Tensor((32,), dtype="float32"), tir_vars=R.ShapeExpr([32])
            )
            gv = (lv1, lv2)
            R.output(gv)
        return gv


# Cache the fixture's main function and its dataflow-block bindings.  Many
# tests index into `bindings`: 0 = tir_matmul call, 1 = tir_relu call,
# 2 = tir_zeros call (see Module.main above).
main_fn = Module["main"]
bindings = main_fn.body.blocks[0].bindings


## Node-wise Matching
def test_expr_pattern():
    """is_expr wraps a concrete Relax expression into an ExprPattern."""
    pattern = is_expr(rx.Var("x"))
    assert isinstance(pattern, ExprPattern)
    assert isinstance(pattern.expr, rx.Var)


def test_var_pattern():
    """is_var matches Var (and its DataflowVar subclass) by optional name."""
    named = is_var("x")
    assert isinstance(named, VarPattern)
    assert named.name == "x"
    assert named.match(rx.Var("x"))
    assert is_var().match(rx.Var("x"))
    assert is_var().match(rx.DataflowVar("x"))  # DataflowVar is also a Var
    # A GlobalVar is not a Var node, so the named pattern rejects it.
    assert not named.match(rx.GlobalVar("x"))


def test_dataflow_var_pattern():
    """is_dfv matches only DataflowVar nodes."""
    pattern = is_dfv("x")
    assert isinstance(pattern, DataflowVarPattern)
    assert pattern.name == "x"
    assert pattern.match(rx.DataflowVar("x"))
    assert not pattern.match(rx.GlobalVar("x"))
    # The anonymous form matches any DataflowVar, e.g. lv0 in the fixture.
    assert is_dfv().match(bindings[0].var)


def test_global_var_pattern():
    """is_gv matches GlobalVar nodes, optionally filtered by name."""
    named = is_gv("x")
    assert named.match(rx.GlobalVar("x"))
    # TODO: disabled as regex is not supported due to
    # symbol conflict with PyTorch
    # assert is_gv("x.*").match(rx.GlobalVar("x_2"))
    assert is_gv().match(rx.GlobalVar("x"))
    assert not named.match(rx.GlobalVar("y"))
    # A plain Var is rejected even with a matching name.
    assert not named.match(rx.Var("x"))


def test_constant_pattern():
    """is_const produces a ConstantPattern matching relax constants."""
    pattern = is_const()
    assert isinstance(pattern, ConstantPattern)
    value = rx.const([[0.1, 1.1, 2.1], [3.1, 4.1, 5.1]])
    assert pattern.match(value)


def test_wildcard_pattern():
    """wildcard() matches any expression."""
    anything = wildcard()
    assert isinstance(anything, WildcardPattern)
    assert anything.match(rx.Var("x"))


def test_call_pattern():
    """is_op(...)(args) builds a CallPattern over argument sub-patterns."""
    lhs = wildcard()
    rhs = wildcard()
    call = is_op("relax.add")(lhs, rhs)
    assert isinstance(call, CallPattern)
    for arg in call.args:
        assert isinstance(arg, WildcardPattern)
    assert call.match(rx.op.add(rx.Var("x"), rx.Var("y")))


def test_function_pattern():
    """FunctionPattern matches a relax.Function's parameters and body."""
    p0 = wildcard()
    p1 = wildcard()
    fpat = FunctionPattern([p0, p1], is_op("relax.add")(p0, p1))
    assert isinstance(fpat, FunctionPattern)
    for param in fpat.params:
        assert isinstance(param, WildcardPattern)
    assert isinstance(fpat.body, CallPattern)
    for arg in fpat.body.args:
        assert isinstance(arg, WildcardPattern)
    x = rx.Var("x", R.Tensor("float32"))
    y = rx.Var("y", R.Tensor("float32"))
    add_fn = rx.Function([x, y], rx.op.add(x, y), ret_struct_info=R.Tensor("float32"))
    mul_fn = rx.Function([x, y], rx.op.multiply(x, y), ret_struct_info=R.Tensor("float32"))
    assert fpat.match(add_fn)
    # The body pattern requires relax.add, so a multiply body is rejected.
    assert not fpat.match(mul_fn)


def test_tuple_pattern():
    """is_tuple matches tuples field-by-field; indexing yields item patterns."""
    first = wildcard()
    second = is_dfv()
    tup = is_tuple([first, second])
    assert isinstance(tup, TuplePattern)
    assert isinstance(tup.fields[0], WildcardPattern)
    assert isinstance(tup.fields[1], DataflowVarPattern)
    concrete = rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")])
    assert tup.match(concrete)
    # Field order matters for an ordered tuple pattern.
    assert not tup.match(rx.Tuple([rx.DataflowVar("x"), rx.GlobalVar("y")]))
    assert not tup.match(rx.Tuple([]))
    assert tup[0].match(rx.TupleGetItem(concrete, 0))
    assert tup[1].match(rx.TupleGetItem(concrete, 1))
    # Negative index is also allowed
    assert tup[-1].match(rx.TupleGetItem(concrete, 1))
    # None means any index.
    assert tup[None].match(rx.TupleGetItem(concrete, 0))
    assert tup[None].match(rx.TupleGetItem(concrete, 1))
    with pytest.raises(IndexError):
        tup[2]  # index cannot be greater than or equal to the tuple size.


def test_unordered_tuple_pattern():
    """unordered=True lets tuple fields match in any permutation."""
    tup = is_tuple([is_const(), is_dfv()], unordered=True)
    assert isinstance(tup, UnorderedTuplePattern)
    assert isinstance(tup.fields[0], ConstantPattern)
    assert isinstance(tup.fields[1], DataflowVarPattern)
    # Both orders of (const, dataflow var) are accepted.
    assert tup.match(rx.Tuple([rx.const([]), rx.DataflowVar("x")]))
    assert tup.match(rx.Tuple([rx.DataflowVar("x"), rx.const([])]))
    # Every field pattern must still be satisfied by some element.
    assert not tup.match(rx.Tuple([rx.DataflowVar("x"), rx.DataflowVar("y")]))
    assert not tup.match(rx.Tuple([]))


def test_tuple_get_item_pattern():
    """is_tuple_get_item matches a TupleGetItem by tuple pattern and index.

    The second assertion previously duplicated the index-0 case verbatim;
    it now exercises index 1, which is what the duplication suggests was
    intended.
    """
    assert is_tuple_get_item(is_tuple([is_gv("x"), is_dfv("y")]), 0).match(
        rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 0)
    )
    assert is_tuple_get_item(is_tuple([is_gv("x"), is_dfv("y")]), 1).match(
        rx.TupleGetItem(rx.Tuple([rx.GlobalVar("x"), rx.DataflowVar("y")]), 1)
    )


def test_or_pattern():
    """OrPattern (|) matches when either alternative matches."""
    either = is_dfv("x") | is_gv("x")
    assert isinstance(either, OrPattern)
    assert either.match(rx.DataflowVar("x"))
    assert either.match(rx.GlobalVar("x"))
    # A plain Var satisfies neither branch; wrong names also fail.
    assert not either.match(rx.Var("x"))
    assert not either.match(rx.DataflowVar("y"))
    assert not either.match(rx.GlobalVar("y"))


def test_and_pattern():
    """AndPattern (&) requires every conjunct: shape (2, 3, 3) and float32."""
    conj = wildcard().has_shape((2, 3, 3)) & has_dtype("float32")
    assert isinstance(conj, AndPattern)
    assert conj.match(rx.Var("x", R.Tensor((2, 3, 3), "float32")))
    # Wrong shape fails the shape conjunct.
    assert not conj.match(rx.Var("x", R.Tensor((3, 3, 3), "float32")))
    # A tensor with only ndim known cannot satisfy a concrete shape.
    assert not conj.match(rx.Var("x", R.Tensor("float32", ndim=3)))


def test_not_pattern():
    """NotPattern (~) inverts the wrapped pattern's verdict."""
    not_233 = ~wildcard().has_shape((2, 3, 3))
    assert isinstance(not_233, NotPattern)
    assert not_233.match(rx.Var("x", R.Tensor((3, 3, 3), "float32")))
    assert not not_233.match(rx.Var("x", R.Tensor((2, 3, 3), "float32")))


def test_dtype_pattern():
    """has_dtype matches tensors by element dtype."""
    pattern = has_dtype("float16")
    assert isinstance(pattern, DataTypePattern)
    assert pattern.dtype == "float16"
    # lv0 in the fixture is declared float32, so a float32 pattern matches it.
    assert has_dtype("float32").match(bindings[0].var)


def test_shape_pattern():
    """has_shape constrains a tensor's (possibly symbolic) shape.

    Symbolic dimensions are compared structurally, so arithmetically
    equivalent forms such as ``n + m`` and ``m + n`` still match, while
    reordered dimensions do not.
    """
    shape = [32, 32]
    pattern = wildcard().has_shape(shape)
    assert isinstance(pattern, ShapePattern)
    # The pattern must record the requested dimensions.  The original code
    # called tvm.ir.structural_equal here but discarded the result; compare
    # (and assert) each dimension instead.
    assert len(pattern.shape) == len(shape)
    for dim, expected in zip(pattern.shape, shape):
        assert dim == expected
    assert pattern.match(bindings[0].var)
    assert wildcard().has_shape([32, 32]).match(bindings[0].var)
    n, m = tir.Var("n", dtype="int64"), tir.Var("m", dtype="int64")
    symsh_var = rx.Var("x", R.Tensor([n, m, n + m], "float32"))
    assert wildcard().has_shape([n, m, n + m]).match(symsh_var)
    assert wildcard().has_shape([n, m, m + n]).match(symsh_var)  # + is commutative.
    assert not wildcard().has_shape([1, 2, 3]).match(symsh_var)
    assert not wildcard().has_shape([m, n, n + m]).match(symsh_var)


def test_prim_arr_pattern():
    """is_shape directly matches a shape value.

    The difference between is_shape and has_shape is that:
    1) is_shape directly matches a shape (e.g., as an argument);
    2) has_shape matches a tensor and puts assumptions on the tensor's shape.
    """
    concrete = is_shape([32, 32])
    assert isinstance(concrete, PrimArrPattern)
    assert concrete[0] == 32
    assert concrete[1] == 32
    assert concrete.match(rx.get_shape_of(bindings[0].var))
    n, m = tir.Var("n", dtype="int64"), tir.Var("m", dtype="int64")
    symbolic = rx.ShapeExpr([n, m, n + m])
    assert is_shape([n, m, n + m]).match(symbolic)
    # A structurally different expression (n * m) does not match n + m.
    assert not is_shape([n, m, n * m]).match(symbolic)


def test_extern_fn_pattern():
    """ExternFuncPattern matches an ExternFunc by its global symbol."""
    symbol = "test.blockbuilder.nop"
    assert ExternFuncPattern(symbol).match(rx.ExternFunc(symbol))


def test_op_attr():
    """has_attr filters call patterns by operator attributes."""
    x = rx.Var("x", R.Tensor("float32"))
    y = rx.Var("y", R.Tensor("float32"))
    conv2d = rx.op.nn.conv2d(x, y, strides=(3, 3))
    xp = is_var("x")
    yp = is_var("y")
    # TODO(@yuchen): reenable the assert after figuring out why it fails
    # assert is_op("nn.conv2d")(xp, yp).has_attr({"strides": [3, 3]}).match(conv2d)
    pat_43 = is_op("relax.nn.conv2d")(xp, yp).has_attr({"strides": [4, 3]})
    assert not pat_43.match(conv2d)
    pat_33 = is_op("relax.nn.conv2d")(xp, yp).has_attr({"strides": [3, 3]})
    assert not pat_33.match(conv2d)


def test_match_call_attr():
    """has_attr on a FunctionPattern checks function-level attributes."""
    x = rx.Var("x", R.Tensor("float32"))
    y = rx.Var("y", R.Tensor("float32"))
    fn = rx.Function([x, y], rx.op.add(x, y), ret_struct_info=R.Tensor("float32"))
    annotated_fn = fn.with_attr({"Codegen": "test-codegen", "global_symbol": "test-symbol"})
    xp = is_var("x")
    yp = is_var("y")
    root_pattern = FunctionPattern([xp, yp], is_op("relax.add")(xp, yp))

    # All requested attributes present and equal -> match.
    both = {"Codegen": "test-codegen", "global_symbol": "test-symbol"}
    assert root_pattern.has_attr(both).match(annotated_fn)
    # A subset of the attributes is sufficient.
    assert root_pattern.has_attr({"Codegen": "test-codegen"}).match(annotated_fn)
    # An attribute the function does not carry fails the match.
    assert not root_pattern.has_attr({"ping": "pong"}).match(annotated_fn)
    # The empty constraint always holds.
    assert root_pattern.has_attr({}).match(annotated_fn)


def test_is_call_tir():
    """is_call_tir matches call_tir nodes by PrimFunc name and arguments."""
    relu_call = bindings[1].value
    zeros_call = bindings[2].value
    var2val = get_var2val(Module["main"])
    assert is_call_tir("tir_relu").match(relu_call)
    # With var2val, argument sub-patterns look through variable bindings:
    # lv1's argument is lv0, which is bound to the tir_matmul call.
    assert is_call_tir("tir_relu", [is_call_tir("tir_matmul")]).match(relu_call, var2val=var2val)
    assert not is_call_tir("tir_relu", [is_call_tir("tir_relu")]).match(relu_call, var2val=var2val)
    assert is_call_tir("tir_zeros", wildcard(), wildcard()).match(zeros_call, var2val=var2val)


# Fixture: an impure function whose single binding is a call_packed to
# "test.vm.mul"; used by the call_packed matching tests below.
@R.function(pure=False)
def simple_call_packed(
    x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")
) -> R.Tensor:
    gv0 = R.call_packed("test.vm.mul", x, w, sinfo_args=(R.Tensor(ndim=2, dtype="float32")))
    return gv0


def test_varg_default_wildcard():
    """varg_default_wildcard=True pads missing argument patterns with wildcards."""
    expr = simple_call_packed.body.blocks[0].bindings[0].value
    explicit = ExternFuncPattern("test.vm.mul")(wildcard(), wildcard())
    implicit = ExternFuncPattern("test.vm.mul")(varg_default_wildcard=True)
    arity_mismatch = ExternFuncPattern("test.vm.mul")(wildcard())

    assert explicit.match(expr)
    assert implicit.match(expr)
    # A single explicit wildcard cannot match the two-argument call.
    assert not arity_mismatch.match(expr)


def test_simple_call_packed():
    """is_call_packed matches by extern name and optional argument patterns."""
    call = simple_call_packed.body.blocks[0].bindings[0].value
    assert is_call_packed("test.vm.mul").match(call)
    assert is_call_packed("test.vm.mul", [is_var("x"), is_var("w")]).match(call)


## Graph-wise Matching
def test_simple_used_by():
    """The ^ operator declares a used-by edge between graph patterns."""
    with PatternContext() as ctx:
        producer = is_var("x")  # x is a free var (fn arg)
        consumer = wildcard()
        producer ^ consumer
        dfb = main_fn.body.blocks[0]
        matched = ctx.match_dfb(dfb)
        assert matched
        # x is the first parameter; its (only) consumer binds lv0.
        assert matched[producer] == main_fn.params[0]
        assert matched[consumer] == dfb.bindings[0].var


def test_simple_call_tir_edge():
    """used_by() is the method spelling of the ^ edge."""
    with PatternContext() as ctx:
        matmul = is_call_tir("tir_matmul")
        relu = is_call_tir("tir_relu")
        matmul.used_by(relu)
        dfb = main_fn.body.blocks[0]
        matched = ctx.match_dfb(dfb)
        assert matched
        assert matched[matmul] == dfb.bindings[0].var
        assert matched[relu] == dfb.bindings[1].var


def test_simple_oub():
    """The >> operator declares an only-used-by edge."""
    with PatternContext() as ctx:
        matmul = is_call_tir("tir_matmul")
        relu = is_call_tir("tir_relu")
        matmul >> relu
        dfb = main_fn.body.blocks[0]
        matched = ctx.match_dfb(dfb)
        assert matched
        # lv0 (matmul) is consumed only by lv1 (relu) in the fixture.
        assert matched[matmul] == dfb.bindings[0].var
        assert matched[relu] == dfb.bindings[1].var


def test_counter_syntax_match():
    """Edges towards a non-existent callee fail for both >> and ^."""
    dfb = main_fn.body.blocks[0]

    with PatternContext() as ctx:
        src = is_call_dps_packed("extern_matmul")
        dst = is_call_dps_packed("extern_impossible")
        src >> dst
        assert not ctx.match_dfb(dfb)

    with PatternContext() as ctx:
        src = is_call_dps_packed("extern_matmul")
        dst = is_call_dps_packed("extern_impossible")
        src ^ dst
        assert not ctx.match_dfb(dfb)


@tvm.script.ir_module
class Diamond:
    # Fixture: a diamond dataflow — matmul feeds both relu and sigmoid,
    # whose results are combined by add.
    @R.function
    def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor:
        with R.dataflow():
            #   matmul
            #  /      \
            # relu  sigmoid
            #  \      /
            #    add
            lv0 = R.call_dps_packed("extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32"))
            lv1 = R.call_dps_packed("extern_relu", (lv0,), R.Tensor((32, 32), dtype="float32"))
            lv2 = R.call_dps_packed("extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32"))
            lv3 = R.call_dps_packed("extern_add", (lv1, lv2), R.Tensor((32, 32), dtype="float32"))
            R.output(lv3)
        return lv3


def test_diamond():
    """Match the diamond with explicit edges, then more compactly via fork_to."""
    dfb = Diamond["main"].body.blocks[0]

    with PatternContext() as ctx:
        matmul = is_call_dps_packed("extern_matmul")
        relu = is_call_dps_packed("extern_relu")
        sigmoid = is_call_dps_packed("extern_sigmoid")
        add = is_call_dps_packed("extern_add")

        # matmul is used by both branches (^), each branch feeds only add (>>).
        matmul ^ relu
        matmul ^ sigmoid
        relu >> add
        sigmoid >> add

        assert ctx.match_dfb(dfb)

    # simplify it with fork_to
    with PatternContext() as ctx:
        relu = is_call_dps_packed("extern_relu")
        sigmoid = is_call_dps_packed("extern_sigmoid")
        add = is_call_dps_packed("extern_add")

        is_call_dps_packed("extern_matmul").fork_to(relu, sigmoid)
        relu >> add
        sigmoid >> add

        assert ctx.match_dfb(dfb)


def test_diamond_counter_oub():
    """matmul has two consumers, so only-used-by (>>) from it must fail."""
    with PatternContext() as ctx:
        matmul = is_call_dps_packed("extern_matmul")
        relu = is_call_dps_packed("extern_relu")
        sigmoid = is_call_dps_packed("extern_sigmoid")
        add = is_call_dps_packed("extern_add")

        matmul >> relu
        matmul >> sigmoid
        relu >> add
        sigmoid >> add

        assert not ctx.match_dfb(Diamond["main"].body.blocks[0])


@tvm.script.ir_module
class SmallDiamond:
    # Fixture: one relu whose single result is used twice by the same add.
    @R.function
    def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
        with R.dataflow():
            #    relu
            #  /      \
            #  \      /
            #    add
            lv0 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32"))
            lv1 = R.call_dps_packed("my_add", (lv0, lv0), R.Tensor((32, 32), dtype="float32"))
            R.output(lv1)
        return lv1


@tvm.script.ir_module
class SmallParallel:
    # Fixture: two distinct relu calls feeding the two operands of one add.
    @R.function
    def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
        with R.dataflow():
            # relu   relu
            #   \    /
            #    add
            lv0 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32"))
            lv1 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32"))
            lv2 = R.call_dps_packed("my_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32"))
            R.output(lv2)
        return lv2


def test_distinguish_diamond_and_parallel():
    """only_used_by with operand indices tells a diamond from parallel producers."""
    # pattern lang cannot distinguish the two cases above.
    diamond = SmallDiamond["main"].body.blocks[0]
    parallel = SmallParallel["main"].body.blocks[0]

    with PatternContext() as ctx:
        # describe a diamond pattern: a single relu feeds BOTH add operands.
        fork = is_call_dps_packed("my_relu")
        join = is_call_dps_packed("my_add")
        fork.only_used_by(join, index=0)
        fork.only_used_by(join, index=1)

        assert ctx.match_dfb(diamond)
        assert not ctx.match_dfb(parallel)

    with PatternContext() as ctx:
        # describe a parallel pattern
        join = is_call_dps_packed("my_add")
        # Due to one-one matching, each is_call_dps_packed("my_relu") call
        # creates a distinct pattern object (distinct obj address), so two
        # separate relu producers are required.
        is_call_dps_packed("my_relu") >> join
        is_call_dps_packed("my_relu") >> join

        assert ctx.match_dfb(parallel)
        assert not ctx.match_dfb(diamond)


@tvm.script.ir_module
class CBRx2:
    # Fixture: two conv1x1 -> bias_add -> my_relu (CBR) branches from the
    # same input, joined by a concat.
    @R.function
    def main(
        x: R.Tensor((32, 32), "float32"),
        w0: R.Tensor((1, 1), "float32"),
        bias0: R.Tensor((32, 32), "float32"),
        w1: R.Tensor((1, 1), "float32"),
        bias1: R.Tensor((32, 32), "float32"),
    ) -> R.Tensor:
        # R.TensorRT's CBR Optimization Pattern
        #     input
        #     /   \
        #  cbr0   cbr1
        #     \   /
        #     concat
        with R.dataflow():
            lv0 = R.call_dps_packed("conv1x1", (x, w0), R.Tensor((32, 32), dtype="float32"))
            lv1 = R.call_dps_packed("bias_add", (lv0, bias0), R.Tensor((32, 32), dtype="float32"))
            lv2 = R.call_dps_packed("my_relu", (lv1), R.Tensor((32, 32), dtype="float32"))
            lv3 = R.call_dps_packed("conv1x1", (x, w1), R.Tensor((32, 32), dtype="float32"))
            lv4 = R.call_dps_packed("bias_add", (lv3, bias1), R.Tensor((32, 32), dtype="float32"))
            lv5 = R.call_dps_packed("my_relu", (lv4), R.Tensor((32, 32), dtype="float32"))
            lv6 = R.call_dps_packed("concat", (lv2, lv5), R.Tensor((32, 64), dtype="float32"))
            R.output(lv6)
        return lv6


def test_nested_context():
    """Nested (non-incremental) contexts are independent, and exiting a
    context restores the enclosing one as PatternContext.current()."""
    dfb = CBRx2["main"].body.blocks[0]
    with PatternContext() as ctx0:
        # conv1x1 -> bias_add -> my_relu exists in CBRx2, so ctx0 matches.
        (
            is_call_dps_packed("conv1x1")
            >> is_call_dps_packed("bias_add")
            >> is_call_dps_packed("my_relu")
        )
        with PatternContext() as ctx1:
            is_call_dps_packed("conv1x1") >> is_call_dps_packed("my_relu")  # pattern to miss
            with PatternContext() as ctx2:
                # bias_add -> my_relu exists, so the innermost context matches
                # even though the enclosing ctx1 will not.
                is_call_dps_packed("bias_add") >> is_call_dps_packed("my_relu")
                assert ctx2.match_dfb(dfb)
                assert PatternContext.current() == ctx2
            assert not ctx1.match_dfb(dfb)
            assert PatternContext.current() == ctx1
        assert ctx0.match_dfb(dfb)
        assert PatternContext.current() == ctx0


def test_two_cbr():
    """dup() clones a chain so one input can fork into two CBR branches."""
    dfb = CBRx2["main"].body.blocks[0]

    with PatternContext() as ctx:
        chain = (
            is_call_dps_packed("conv1x1")
            >> is_call_dps_packed("bias_add")
            >> is_call_dps_packed("my_relu")
        )
        clone = chain.dup()

        # The duplicate is made of fresh pattern objects, element by element.
        for original, duplicate in zip(chain.patterns, clone.patterns):
            assert original != duplicate

        is_var("x").fork_to(chain, clone)
        assert ctx.match_dfb(dfb)

    with PatternContext() as ctx:
        # Deny the pattern
        chain = (
            is_call_dps_packed("conv1x1")
            >> is_call_dps_packed("bias_add")
            >> is_call_dps_packed("my_relu")
        )
        clone = chain.dup()

        # input has no fork at y.
        is_var("y").fork_to(chain, clone)
        assert not ctx.match_dfb(dfb)


def test_two_matmul():
    """A >> chain must match exactly as many chained producers as exist."""
    # Same as Figure 2(a) in TASO paper.
    @tvm.script.ir_module
    class MatMul2:
        @R.function
        def main(
            a: R.Tensor((32, 16), "float32"),
            b: R.Tensor((16, 48), "float32"),
            c: R.Tensor((48, 32), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                lv0 = R.call_dps_packed("matmul", (a, b), R.Tensor((32, 48), dtype="float32"))
                lv1 = R.call_dps_packed("matmul", (lv0, c), R.Tensor((32, 32), dtype="float32"))
                R.output(lv1)
            return lv1

    with PatternContext() as ctx:
        # Two chained matmuls exist, so matmul >> matmul matches.
        is_call_dps_packed("matmul") >> is_call_dps_packed("matmul")
        dfb = MatMul2["main"].body.blocks[0]
        assert ctx.match_dfb(dfb)

    with PatternContext() as ctx:
        # Per-node shape constraints still allow the match.
        is_call_dps_packed("matmul").has_shape([32, 48]) >> is_call_dps_packed("matmul").has_shape(
            [32, 32]
        )
        dfb = MatMul2["main"].body.blocks[0]
        assert ctx.match_dfb(dfb)

    with PatternContext() as ctx:
        is_call_dps_packed("matmul") >> is_call_dps_packed("matmul") >> is_call_dps_packed("matmul")
        dfb = MatMul2["main"].body.blocks[0]
        # Three MatMul cannot match
        assert not ctx.match_dfb(dfb)


def test_concat_mm_split():
    """Match a concat->matmul->split chain, then the split's TupleGetItems."""
    # Same as Figure 2(b) in TASO paper.
    @tvm.script.ir_module
    class CMS:
        @R.function
        def main(
            a: R.Tensor((32, 32), "float32"),
            b: R.Tensor((16, 32), "float32"),
            c: R.Tensor((16, 32), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                lv0 = R.call_dps_packed("my_concat", (b, c), R.Tensor((32, 32), dtype="float32"))
                lv1 = R.call_dps_packed("my_matmul", (a, lv0), R.Tensor((32, 32), dtype="float32"))
                lv2 = R.call_dps_packed(
                    "my_split",
                    (lv1,),
                    [R.Tensor((16, 32), dtype="float32"), R.Tensor((16, 32), dtype="float32")],
                )
                lv3 = R.TupleGetItem(lv2, 0)
                lv4 = R.TupleGetItem(lv2, 1)
                lv5 = R.add(lv3, lv4)
                R.output(lv5)
            return lv5

    with PatternContext() as ctx:
        (
            is_call_dps_packed("my_concat")
            >> is_call_dps_packed("my_matmul")
            >> is_call_dps_packed("my_split")
        )
        dfb = CMS["main"].body.blocks[0]
        assert ctx.match_dfb(dfb)

    with PatternContext() as ctx:
        # my_split yields a 2-tuple; match each element and their final add.
        split = is_call_dps_packed("my_split")
        lv3 = TupleGetItemPattern(split, 0).has_shape([16, 32])
        lv4 = TupleGetItemPattern(split, 1).has_shape([16, 32])
        split.fork_to(lv3, lv4)
        add = is_op("relax.add")(lv3, lv4)
        # TODO(@ganler): simplify this through implicit graph pattern.
        lv3 >> add
        lv4 >> add

        dfb = CMS["main"].body.blocks[0]
        assert ctx.match_dfb(dfb)


def test_self_attention():
    """fork_to with dup()-ed chains matches the three QKV projection branches."""
    # The example comes from.
    # https://developer.nvidia.com/blog/nlu-with-tensorrt-bert/
    @tvm.script.ir_module
    class SelfAttention:
        @R.function
        def main(
            x: R.Tensor(("b", "s", "n", "h"), "float32"),
            wq: R.Tensor(("h", "h"), "float32"),
            wk: R.Tensor(("h", "h"), "float32"),
            wv: R.Tensor(("h", "h"), "float32"),
        ) -> R.Tensor:
            b, s, n, h = T.int64(), T.int64(), T.int64(), T.int64()
            with R.dataflow():
                fcq = R.call_dps_packed("my_fc", (x, wq), R.Tensor((b, s, n, h), dtype="float32"))
                tpq = R.call_dps_packed(
                    "my_transpose", (fcq,), R.Tensor((b, s, h, n), dtype="float32")
                )

                fck = R.call_dps_packed("my_fc", (x, wk), R.Tensor((b, s, n, h), dtype="float32"))
                tpk = R.call_dps_packed(
                    "my_transpose", (fck,), R.Tensor((b, s, h, n), dtype="float32")
                )

                mul = R.multiply(tpq, tpk)
                scale = R.multiply(mul, R.const(1.1, "float32"))
                softmax = R.call_dps_packed(
                    "softmax", (scale,), R.Tensor((b, s, n, h), dtype="float32")
                )

                fcv = R.call_dps_packed("my_fc", (x, wv), R.Tensor((b, s, n, h), dtype="float32"))
                tpv = R.call_dps_packed(
                    "my_transpose", (fcv,), R.Tensor((b, s, h, n), dtype="float32")
                )

                out = R.multiply(softmax, tpv)
                R.output(out)

            return out

    with PatternContext() as ctx:
        # One my_fc -> my_transpose chain, duplicated for the K and V branches.
        fc_trans_q = is_call_dps_packed("my_fc") >> is_call_dps_packed("my_transpose")
        fc_trans_k = fc_trans_q.dup()
        fc_trans_v = fc_trans_q.dup()

        is_var("x").fork_to(fc_trans_q, fc_trans_k, fc_trans_v)
        dfb = SelfAttention["main"].body.blocks[0]
        assert ctx.match_dfb(dfb)


def test_nested_diamond():
    """Match sub-diamonds of a graph where two diamonds share nodes."""
    @tvm.script.ir_module
    class DiamondInDiamond:
        @R.function
        def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor:
            with R.dataflow():
                #   matmul0      matmul1
                #     /    \    /    \
                # sigmoid2  add4  sigmoid3
                #     \    /    \    /
                #      add5      add6
                #          \    /
                #           add7
                lv0 = R.call_dps_packed(
                    "extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32")
                )
                lv1 = R.call_dps_packed(
                    "extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32")
                )
                lv2 = R.call_dps_packed(
                    "extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")
                )
                lv3 = R.call_dps_packed(
                    "extern_sigmoid", (lv1), R.Tensor((32, 32), dtype="float32")
                )
                lv4 = R.call_dps_packed(
                    "extern_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32")
                )
                lv5 = R.call_dps_packed(
                    "extern_add", (lv2, lv4), R.Tensor((32, 32), dtype="float32")
                )
                lv6 = R.call_dps_packed(
                    "extern_add", (lv3, lv4), R.Tensor((32, 32), dtype="float32")
                )
                lv7 = R.call_dps_packed(
                    "extern_add", (lv5, lv6), R.Tensor((32, 32), dtype="float32")
                )
                R.output(lv7)
            return lv7

    # match matmul0 diamond
    with PatternContext() as ctx:
        sigmoid2 = is_call_dps_packed("extern_sigmoid")
        add4 = is_call_dps_packed("extern_add")
        is_call_dps_packed("extern_matmul").fork_to(sigmoid2, add4)
        add5 = is_call_dps_packed("extern_add")
        sigmoid2 >> add5
        # add4 feeds both add5 and add6, so only the weaker ^ edge can hold.
        add4 ^ add5
        assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])

    # counter case: mis-match matmul0 diamond
    with PatternContext() as ctx:
        sigmoid2 = is_call_dps_packed("extern_sigmoid")
        add4 = is_call_dps_packed("extern_add")
        is_call_dps_packed("extern_matmul").fork_to(sigmoid2, add4)
        add5 = is_call_dps_packed("extern_add")
        sigmoid2 >> add5
        add4 >> add5  # not only-used-by relation
        assert not ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])

    # match matmul1 diamond
    with PatternContext() as ctx:
        sigmoid3 = is_call_dps_packed("extern_sigmoid")
        add4 = is_call_dps_packed("extern_add")
        is_call_dps_packed("extern_matmul").fork_to(sigmoid3, add4)
        add6 = is_call_dps_packed("extern_add")
        sigmoid3 >> add6
        add4 ^ add6
        assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])

    # match add-4-5-6-7
    with PatternContext() as ctx:
        add5, add6, add7 = (
            is_call_dps_packed("extern_add"),
            is_call_dps_packed("extern_add"),
            is_call_dps_packed("extern_add"),
        )
        is_call_dps_packed("extern_add").fork_to(add5, add6)  # add4
        add5 >> add7
        add6 >> add7
        assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0])


def test_incremental_solving():
    """PatternContext(incremental=True) inherits the parent's constraints."""
    @R.function
    def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
        with R.dataflow():
            # relu -> sigmoid -> neg
            lv0 = R.call_dps_packed("extern_relu", (x), R.Tensor((32, 32), dtype="float32"))
            lv1 = R.call_dps_packed("extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32"))
            lv2 = R.call_dps_packed("extern_neg", (lv1), R.Tensor((32, 32), dtype="float32"))
            R.output(lv2)
        return lv2

    relu = is_call_dps_packed("extern_relu")
    sigmoid = is_call_dps_packed("extern_sigmoid")
    neg = is_call_dps_packed("extern_neg")

    with PatternContext() as ctx0:
        relu >> sigmoid
        with PatternContext(incremental=True) as ctx1:
            # because we are doing incremental solving
            # relu >> sigmoid is still a constraint in this context.
            # that said the total constraint is:
            # relu >> sigmoid >> neg
            sigmoid >> neg
            assert ctx1.match_dfb(simple_chain.body.blocks[0])

        # match relu -> sigmoid
        assert ctx0.match_dfb(simple_chain.body.blocks[0])


def test_incremental_solving_counter():
    """Inherited constraints can make an otherwise-matching context fail."""
    @R.function
    def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
        with R.dataflow():
            # sigmoid -> neg
            lv0 = R.call_dps_packed("extern_sigmoid", (x), R.Tensor((32, 32), dtype="float32"))
            lv1 = R.call_dps_packed("extern_neg", (lv0), R.Tensor((32, 32), dtype="float32"))
            R.output(lv1)
        return lv1

    relu = is_call_dps_packed("extern_relu")
    sigmoid = is_call_dps_packed("extern_sigmoid")
    neg = is_call_dps_packed("extern_neg")

    with PatternContext() as ctx0:
        relu >> sigmoid  # cannot match: the chain has no relu

        with PatternContext(incremental=False) as ctx1:
            # total constraint: sigmoid >> neg
            sigmoid >> neg
            assert ctx1.match_dfb(simple_chain.body.blocks[0])

        with PatternContext(incremental=True) as ctx1:
            # total constraint: relu >> sigmoid >> neg
            sigmoid >> neg
            assert not ctx1.match_dfb(simple_chain.body.blocks[0])


def test_rewrite_simple():
    """rewrite_call replaces every match; returning the original call is a no-op."""

    @R.function
    def main(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16), "float32"):
        with R.dataflow():
            x2 = R.add(x, x)
            x4 = R.add(x2, x2)
            R.output(x4)
        return x4

    @R.function
    def expected1(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(2, "float32"))
            x4: R.Tensor((16, 16), dtype="float32") = R.multiply(lv, R.const(2, "float32"))
            R.output(x4)
        return x4

    @R.function
    def expected2(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
        with R.dataflow():
            x4: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(4, "float32"))
            R.output(x4)
        return x4

    x = wildcard()
    pattern = is_op("relax.add")(x, x)

    # Rewrite x + x -> x * 2; both adds in `main` match independently.
    def rewriter(_, matchings):
        return R.multiply(matchings[x], R.const(2, "float32"))

    rewritten = rewrite_call(pattern, rewriter, main)
    tvm.ir.assert_structural_equal(rewritten, expected1.with_attr("global_symbol", "main"))

    # Nested pattern: (x + x) + (x + x) -> x * 4, collapsing both adds at once.
    add1 = is_op("relax.add")(x, x)
    pattern = is_op("relax.add")(add1, add1)

    def rewriter(_, matchings):
        return R.multiply(matchings[x], R.const(4, "float32"))

    rewritten = rewrite_call(pattern, rewriter, main)
    tvm.ir.assert_structural_equal(rewritten, expected2.with_attr("global_symbol", "main"))

    # No rewriting, return the original call node as is
    def rewriter(orig, _):
        return orig

    rewritten = rewrite_call(pattern, rewriter, main)
    tvm.ir.assert_structural_equal(rewritten, main)


def test_rewrite_attention():
    """Fuse an explicitly-written attention computation into R.nn.attention.

    NOTE(review): in the pattern below, the `V` wildcard ends up bound to
    the tensor that `main` transposes for the first matmul (main's K), and
    `K` to the tensor multiplied after the softmax (main's V).  The
    expected output `R.nn.attention(Q, V, K)` is consistent with that
    swapped binding, so the test is self-consistent as written.
    """

    @R.function
    def main(
        Q: R.Tensor((2, 4096, 8, 40), "float32"),
        K: R.Tensor((2, 4096, 8, 40), "float32"),
        V: R.Tensor((2, 4096, 8, 40), "float32"),
    ) -> R.Tensor((2, 4096, 8, 40), "float32"):
        with R.dataflow():
            lv58 = R.permute_dims(Q, axes=[0, 2, 1, 3])
            lv59 = R.reshape(lv58, R.shape([16, 4096, 40]))

            lv61 = R.permute_dims(K, axes=[0, 2, 1, 3])
            lv62 = R.reshape(lv61, R.shape([16, 4096, 40]))

            lv64 = R.permute_dims(V, axes=[0, 2, 1, 3])
            lv65 = R.reshape(lv64, R.shape([16, 4096, 40]))

            lv62_transposed = R.permute_dims(lv62, axes=[0, 2, 1])
            lv3_1 = R.matmul(lv59, lv62_transposed)
            lv68 = R.multiply(lv3_1, R.const(0.15811388194561005, "float32"))
            lv69 = R.nn.softmax(lv68, axis=-1)
            lv_3 = R.matmul(lv69, lv65)

            lv71 = R.reshape(lv_3, R.shape([2, 8, 4096, 40]))
            lv72 = R.permute_dims(lv71, axes=[0, 2, 1, 3])
            R.output(lv72)

        return lv72

    @R.function
    def expected(
        Q: R.Tensor((2, 4096, 8, 40), dtype="float32"),
        K: R.Tensor((2, 4096, 8, 40), dtype="float32"),
        V: R.Tensor((2, 4096, 8, 40), dtype="float32"),
    ) -> R.Tensor((2, 4096, 8, 40), dtype="float32"):
        with R.dataflow():
            lv72: R.Tensor((2, 4096, 8, 40), dtype="float32") = R.nn.attention(Q, V, K)
            R.output(lv72)
        return lv72

    # 4-D -> 3-D: a permute_dims followed by a reshape
    # (names suggest a BSNH -> BSH layout change -- TODO confirm)
    def BSNH_to_BSH(tensor):
        return is_op("relax.reshape")(is_op("relax.permute_dims")(tensor), wildcard())

    # Inverse: a reshape followed by a permute_dims restores the 4-D layout.
    def BSH_to_BSNH(tensor):
        return is_op("relax.permute_dims")(is_op("relax.reshape")(tensor, wildcard()))

    Q = wildcard()
    K = wildcard()
    V = wildcard()

    Q_3D = BSNH_to_BSH(Q)
    V_3D = BSNH_to_BSH(V)
    K_3D = BSNH_to_BSH(K)

    matmul1 = is_op("relax.matmul")(Q_3D, is_op("relax.permute_dims")(V_3D))
    multiply = is_op("relax.multiply")(matmul1, is_const())
    softmax = is_op("relax.nn.softmax")(multiply)
    matmul2 = is_op("relax.matmul")(softmax, K_3D)

    pattern = BSH_to_BSNH(matmul2)

    def rewriter(_, matchings):
        return R.nn.attention(matchings[Q], matchings[K], matchings[V])

    rewritten = rewrite_call(pattern, rewriter, main)
    tvm.ir.assert_structural_equal(rewritten, expected.with_attr("global_symbol", "main"))


def test_attention_qkv():
    """PatternContext matching binds the weight wildcards of a
    shared-input QKV projection to the correct function parameters."""

    @tvm.script.ir_module
    class QKV_proj:
        @R.function
        def main(
            x: R.Tensor((2, 1024, 640), "float32"),
            w0: R.Tensor((640, 640), "float32"),
            w1: R.Tensor((640, 640), "float32"),
            w2: R.Tensor((640, 640), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                lv0 = R.matmul(x, w0)
                lv1 = R.matmul(x, w1)
                lv2 = R.matmul(x, w2)
                out = (lv0, lv1, lv2)
                R.output(out)
            return out

    with PatternContext() as ctx:
        inp_pat = wildcard()
        Q_weight_pat = wildcard()
        K_weight_pat = wildcard()
        V_weight_pat = wildcard()

        # All three matmuls must share the same first operand (inp_pat).
        matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
        matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
        matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)

        dfb = QKV_proj["main"].body.blocks[0]
        out = ctx.match_dfb(dfb)

        assert out[Q_weight_pat].name_hint == "w0"
        assert out[K_weight_pat].name_hint == "w1"
        assert out[V_weight_pat].name_hint == "w2"


def test_attention_fake_qkv():
    """Matching fails when the three matmuls do not share a single input.

    `lv0` uses `x1` while `lv1`/`lv2` use `x2`, so the shared `inp_pat`
    constraint cannot be satisfied and `match_dfb` returns None.
    """

    @tvm.script.ir_module
    class QKV_proj:
        @R.function
        def main(
            x1: R.Tensor((2, 1024, 640), "float32"),
            x2: R.Tensor((2, 1024, 640), "float32"),
            w0: R.Tensor((640, 640), "float32"),
            w1: R.Tensor((640, 640), "float32"),
            w2: R.Tensor((640, 640), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                lv0 = R.matmul(x1, w0)
                lv1 = R.matmul(x2, w1)
                lv2 = R.matmul(x2, w2)
                out = (lv0, lv1, lv2)
                R.output(out)
            return out

    with PatternContext() as ctx:
        inp_pat = wildcard()
        Q_weight_pat = wildcard()
        K_weight_pat = wildcard()
        V_weight_pat = wildcard()

        matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
        matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
        matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)

        dfb = QKV_proj["main"].body.blocks[0]
        assert ctx.match_dfb(dfb) is None


def get_qkv_proj_rewriter():
    """Build a (PatternContext, rewriter) pair for QKV-projection fusion.

    The context matches three matmuls sharing one input; the rewriter
    replaces them with a single matmul over the concatenated weights,
    followed by three strided slices that recover the original outputs.
    """
    with PatternContext() as ctx:
        inp_pat = wildcard()
        Q_weight_pat = wildcard()
        K_weight_pat = wildcard()
        V_weight_pat = wildcard()

        matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
        matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
        matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)

    def qkv_proj_rewriter(matchings, _):
        inp = matchings[inp_pat]
        Q_weight = matchings[Q_weight_pat]
        K_weight = matchings[K_weight_pat]
        V_weight = matchings[V_weight_pat]
        # Slice boundaries are derived from the matched weight's width.
        width = Q_weight.struct_info.shape[1]

        concat = R.concat([Q_weight, K_weight, V_weight], axis=1)
        matmul = R.matmul(inp, concat)
        Q = R.strided_slice(matmul, axes=[2], begin=[0], end=[width])
        K = R.strided_slice(matmul, axes=[2], begin=[width], end=[width * 2])
        V = R.strided_slice(matmul, axes=[2], begin=[width * 2], end=[width * 3])

        # Map each original matmul binding to its replacement expression.
        return {matchings[matmul1]: Q, matchings[matmul2]: K, matchings[matmul3]: V}

    return ctx, qkv_proj_rewriter


def test_combine_matmul_twice():
    """rewrite_bindings applies the QKV fusion once per distinct input.

    `qkv_x2` has two independent groups of three matmuls (input x1 and
    input x2); each group is fused separately.
    """

    @R.function(private=True)
    def qkv_x2(
        x1: R.Tensor((2, 1024, 640), "float32"),
        x2: R.Tensor((2, 1024, 640), "float32"),
        w0: R.Tensor((640, 640), "float32"),
        w1: R.Tensor((640, 640), "float32"),
        w2: R.Tensor((640, 640), "float32"),
        w3: R.Tensor((640, 640), "float32"),
        w4: R.Tensor((640, 640), "float32"),
        w5: R.Tensor((640, 640), "float32"),
    ):
        with R.dataflow():
            lv0 = R.matmul(x1, w0)
            lv1 = R.matmul(x1, w1)
            lv2 = R.matmul(x1, w2)
            lv3 = R.matmul(x2, w3)
            lv4 = R.matmul(x2, w4)
            lv5 = R.matmul(x2, w5)
            out = (lv0, lv1, lv2, lv3, lv4, lv5)
            R.output(out)
        return out

    @R.function(private=True)
    def expected(
        x1: R.Tensor((2, 1024, 640), "float32"),
        x2: R.Tensor((2, 1024, 640), "float32"),
        w0: R.Tensor((640, 640), "float32"),
        w1: R.Tensor((640, 640), "float32"),
        w2: R.Tensor((640, 640), "float32"),
        w3: R.Tensor((640, 640), "float32"),
        w4: R.Tensor((640, 640), "float32"),
        w5: R.Tensor((640, 640), "float32"),
    ):
        with R.dataflow():
            lv = R.concat((w0, w1, w2), axis=1)
            lv1 = R.matmul(x1, lv)
            lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640])
            lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280])
            lv2 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920])
            lv2_1 = R.concat((w3, w4, w5), axis=1)
            lv3 = R.matmul(x2, lv2_1, out_dtype="void")
            lv3_1 = R.strided_slice(lv3, axes=[2], begin=[0], end=[640])
            lv4 = R.strided_slice(lv3, axes=[2], begin=[640], end=[1280])
            lv5 = R.strided_slice(lv3, axes=[2], begin=[1280], end=[1920])
            out = lv0, lv1_1, lv2, lv3_1, lv4, lv5
            R.output(out)
        return out

    ctx, rewriter = get_qkv_proj_rewriter()
    rewritten = rewrite_bindings(ctx, rewriter, qkv_x2)
    tvm.ir.assert_structural_equal(rewritten, expected)


def test_dataflow_may_start_with_match_cast():
    """Inputs to rewrite_bindings may contain R.match_cast

    This is a regression test.  In previous implementations, applying
    `rewrite_bindings` when `R.match_cast` is the first binding of a
    `R.dataflow` block would cause a segfault.

    """

    @R.function(private=True)
    def before(
        x_untyped: R.Tensor,
        w0_untyped: R.Tensor,
        w1_untyped: R.Tensor,
        w2_untyped: R.Tensor,
    ):
        with R.dataflow():
            # match_cast bindings precede the matmuls being rewritten
            x = R.match_cast(x_untyped, R.Tensor((2, 1024, 640), "float32"))
            w0 = R.match_cast(w0_untyped, R.Tensor((640, 640), "float32"))
            w1 = R.match_cast(w1_untyped, R.Tensor((640, 640), "float32"))
            w2 = R.match_cast(w2_untyped, R.Tensor((640, 640), "float32"))
            out_0 = R.matmul(x, w0)
            out_1 = R.matmul(x, w1)
            out_2 = R.matmul(x, w2)
            out = (out_0, out_1, out_2)
            R.output(out)
        return out

    @R.function(private=True)
    def expected(
        x_untyped: R.Tensor,
        w0_untyped: R.Tensor,
        w1_untyped: R.Tensor,
        w2_untyped: R.Tensor,
    ):
        with R.dataflow():
            x = R.match_cast(x_untyped, R.Tensor((2, 1024, 640), "float32"))
            w0 = R.match_cast(w0_untyped, R.Tensor((640, 640), "float32"))
            w1 = R.match_cast(w1_untyped, R.Tensor((640, 640), "float32"))
            w2 = R.match_cast(w2_untyped, R.Tensor((640, 640), "float32"))
            w_concat = R.concat((w0, w1, w2), axis=1)
            out_concat = R.matmul(x, w_concat)
            out_0 = R.strided_slice(out_concat, axes=[2], begin=[0], end=[640])
            out_1 = R.strided_slice(out_concat, axes=[2], begin=[640], end=[1280])
            out_2 = R.strided_slice(out_concat, axes=[2], begin=[1280], end=[1920])
            out = (out_0, out_1, out_2)
            R.output(out)
        return out

    # Reuse the shared QKV fusion pattern/rewriter.
    ctx, rewriter = get_qkv_proj_rewriter()
    rewritten = rewrite_bindings(ctx, rewriter, before)
    tvm.ir.assert_structural_equal(rewritten, expected)


def test_combine_matmul_emit_order():
    """Fused bindings are emitted after all of their transitive inputs.

    In `expected`, every weight transpose appears before the fused
    concat/matmul it feeds, even though the transposes were interleaved
    with the matmuls in `main`.  The result is also compiled to confirm
    it is well-formed.
    """

    @R.function(private=True)
    def main(
        x1: R.Tensor((2, 1024, 640), "float32"),
        w0: R.Tensor((640, 640), "float32"),
        w1: R.Tensor((640, 640), "float32"),
        w2: R.Tensor((640, 640), "float32"),
    ):
        with R.dataflow():
            w0_t = R.permute_dims(w0, axes=None)
            lv0 = R.matmul(x1, w0_t)
            w1_t = R.permute_dims(w1, axes=None)
            w1_t_t = R.permute_dims(w1_t, axes=None)
            lv1 = R.matmul(x1, w1_t_t)
            w2_t = R.permute_dims(w2, axes=None)
            lv2 = R.matmul(x1, w2_t)
            out = (lv0, lv1, lv2)
            R.output(out)
        return out

    @R.function(private=True)
    def expected(
        x1: R.Tensor((2, 1024, 640), dtype="float32"),
        w0: R.Tensor((640, 640), dtype="float32"),
        w1: R.Tensor((640, 640), dtype="float32"),
        w2: R.Tensor((640, 640), dtype="float32"),
    ):
        with R.dataflow():
            w0_t = R.permute_dims(w0, axes=None)
            w1_t = R.permute_dims(w1, axes=None)
            w1_t_t = R.permute_dims(w1_t, axes=None)
            w2_t = R.permute_dims(w2, axes=None)
            lv = R.concat((w0_t, w1_t_t, w2_t), axis=1)
            lv1 = R.matmul(x1, lv, out_dtype="void")
            lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640])
            lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280])
            lv2 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920])
            out = lv0, lv1_1, lv2
            R.output(out)
        return out

    ctx, rewriter = get_qkv_proj_rewriter()

    rewritten = rewrite_bindings(ctx, rewriter, main)
    tvm.ir.assert_structural_equal(rewritten, expected)

    # make sure it builds
    mod = tvm.IRModule()
    mod["main"] = rewritten

    tvm.compile(mod, target="llvm")


def test_combine_transposed_matmul_twice():
    """Fuse pairs of matmuls whose RHS is a transposed weight.

    Two matmuls that share the same input, each of the form
    ``matmul(x, permute_dims(w))``, are combined into a single matmul
    against the transposed concatenation of the weights, followed by an
    ``R.split`` recovering the two original results.  The rewrite fires
    once per distinct input (x1 and x2), and the rewritten module is
    compiled to confirm it is well-formed.
    """

    @R.function(private=True)
    def main(
        x1: R.Tensor((2, 1024, 640), "float32"),
        x2: R.Tensor((2, 1024, 640), "float32"),
        w0: R.Tensor((640, 640), "float32"),
        w1: R.Tensor((640, 640), "float32"),
        w2: R.Tensor((640, 640), "float32"),
        w3: R.Tensor((640, 640), "float32"),
    ):
        with R.dataflow():
            w0_t = R.permute_dims(w0, axes=None)
            lv0 = R.matmul(x1, w0_t)
            w1_t = R.permute_dims(w1, axes=None)
            lv1 = R.matmul(x1, w1_t)
            w2_t = R.permute_dims(w2, axes=None)
            lv2 = R.matmul(x2, w2_t)
            w3_t = R.permute_dims(w3, axes=None)
            lv3 = R.matmul(x2, w3_t)
            out = (lv0, lv1, lv2, lv3)
            R.output(out)
        return out

    @R.function(private=True)
    def expected(
        x1: R.Tensor((2, 1024, 640), dtype="float32"),
        x2: R.Tensor((2, 1024, 640), dtype="float32"),
        w0: R.Tensor((640, 640), dtype="float32"),
        w1: R.Tensor((640, 640), dtype="float32"),
        w2: R.Tensor((640, 640), dtype="float32"),
        w3: R.Tensor((640, 640), dtype="float32"),
    ):
        with R.dataflow():
            lv: R.Tensor((1280, 640), dtype="float32") = R.concat((w0, w1), axis=0)
            lv1: R.Tensor((640, 1280), dtype="float32") = R.permute_dims(lv, axes=None)
            lv2: R.Tensor((2, 1024, 1280), dtype="float32") = R.matmul(x1, lv1, out_dtype="void")
            lv3: R.Tuple(
                R.Tensor((2, 1024, 640), dtype="float32"),
                R.Tensor((2, 1024, 640), dtype="float32"),
            ) = R.split(lv2, indices_or_sections=[640], axis=-1)
            lv0: R.Tensor((2, 1024, 640), dtype="float32") = lv3[0]
            lv1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv3[1]
            lv_1: R.Tensor((1280, 640), dtype="float32") = R.concat((w2, w3), axis=0)
            lv1_2: R.Tensor((640, 1280), dtype="float32") = R.permute_dims(lv_1, axes=None)
            lv2_1: R.Tensor((2, 1024, 1280), dtype="float32") = R.matmul(
                x2, lv1_2, out_dtype="void"
            )
            lv3_1: R.Tuple(
                R.Tensor((2, 1024, 640), dtype="float32"),
                R.Tensor((2, 1024, 640), dtype="float32"),
            ) = R.split(lv2_1, indices_or_sections=[640], axis=-1)
            lv2_1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv3_1[0]
            lv3_1_1: R.Tensor((2, 1024, 640), dtype="float32") = lv3_1[1]
            out: R.Tuple(
                R.Tensor((2, 1024, 640), dtype="float32"),
                R.Tensor((2, 1024, 640), dtype="float32"),
                R.Tensor((2, 1024, 640), dtype="float32"),
                R.Tensor((2, 1024, 640), dtype="float32"),
            ) = (lv0, lv1_1, lv2_1_1, lv3_1_1)
            R.output(out)
        return out

    with PatternContext() as ctx:
        inp_pat = wildcard()
        w1_pat = wildcard()
        w2_pat = wildcard()
        # Both matmuls share inp_pat; each RHS is a transposed weight.
        matmul1 = is_op("relax.matmul")(inp_pat, is_op("relax.permute_dims")(w1_pat))
        matmul2 = is_op("relax.matmul")(inp_pat, is_op("relax.permute_dims")(w2_pat))

        def rewriter(matchings, _):
            inp = matchings[inp_pat]
            w1 = matchings[w1_pat]
            w2 = matchings[w2_pat]

            # matmul(x, w1^T), matmul(x, w2^T) -> split(matmul(x, concat(w1, w2)^T))
            concat = R.concat([w1, w2], axis=0)
            matmul = R.matmul(inp, R.permute_dims(concat))
            # Split at w1's leading dimension to recover the two outputs.
            sections = [w1.struct_info.shape[0]]

            chunks = R.split(matmul, sections, -1)

            return {
                matchings[matmul1]: chunks[0],
                matchings[matmul2]: chunks[1],
            }

        rewritten = rewrite_bindings(ctx, rewriter, main)
        tvm.ir.assert_structural_equal(rewritten, expected)

        # make sure it builds
        mod = tvm.IRModule()
        mod["main"] = rewritten

        tvm.compile(mod, target="llvm")


def test_commutative_pattern_match():
    """A pattern with a constant operand matches R.add(const, y).

    The pattern places the constant second, while `before` places it
    first; the match still succeeds for the commutative add.
    """

    @R.function(private=True)
    def before(
        x: R.Tensor((1024,)),
    ):
        with R.dataflow():
            y = R.add(x, x)
            out = R.add(R.const(1.0), y)
            R.output(out)
        return out

    @R.function(private=True)
    def expected(
        x: R.Tensor((1024,)),
    ):
        with R.dataflow():
            y = R.add(x, x)
            out = R.add(y, R.const(2.0))
            R.output(out)

        return out

    pattern_add = is_op("relax.add")
    pattern_mul = is_op("relax.multiply")
    pattern_op = pattern_add | pattern_mul
    pattern_arg = wildcard()
    pattern_const = is_const()

    pattern = pattern_op(pattern_arg, pattern_const)

    def rewriter(expr, matches):
        op = matches[pattern_op]
        arg = matches[pattern_arg]
        const = matches[pattern_const].data.numpy()
        # Only rewrite the scalar constant 1.0 -> 2.0; everything else is left alone.
        if const.shape == tuple() and const[()] == 1.0:
            return rx.Call(op, [arg, rx.const(2.0)])
        else:
            return expr

    after = rewrite_call(pattern, rewriter, before)
    tvm.ir.assert_structural_equal(after, expected)


def test_repeated_pattern_match():
    """rewrite_call should iterate until convergence

    Distributing the multiply over the add creates fresh
    multiply-of-add matches, so the rewrite must be re-applied until no
    match remains.
    """

    @R.function(private=True)
    def before(
        x: R.Tensor((1024,)),
        y: R.Tensor((1024,)),
        z: R.Tensor((1024,)),
    ):
        with R.dataflow():
            a = R.add(x, y)
            b = R.add(a, z)
            out = R.multiply(b, R.const(5.0))
            R.output(out)
        return out

    @R.function(private=True)
    def expected(
        x: R.Tensor((1024,)),
        y: R.Tensor((1024,)),
        z: R.Tensor((1024,)),
    ):
        with R.dataflow():
            x = R.multiply(x, R.const(5.0))
            y = R.multiply(y, R.const(5.0))
            a = R.add(x, y)
            z = R.multiply(z, R.const(5.0))
            b = R.add(a, z)
            R.output(b)
        return b

    pattern_add_lhs = wildcard()
    pattern_add_rhs = wildcard()
    pattern_add = is_op("relax.add")(pattern_add_lhs, pattern_add_rhs)

    mul_const = is_const()
    pattern_mul = is_op("relax.multiply")(pattern_add, mul_const)

    pattern = pattern_mul

    # (lhs + rhs) * c  ->  (lhs * c) + (rhs * c)
    def rewriter(_expr, matches):
        const = matches[mul_const]
        return (matches[pattern_add_lhs] * const) + (matches[pattern_add_rhs] * const)

    after = rewrite_call(pattern, rewriter, before)
    tvm.ir.assert_structural_equal(after, expected)


# Pytest parameter: whether the rewritten binding in
# test_rewrite_without_trivial_binding targets a relax.Var bound inside
# an R.dataflow block (DataflowVar) or a plain relax.Var.
bind_to_dataflow_var = tvm.testing.parameter(
    by_dict={"var-to-var": False, "var-to-dataflow-var": True}
)


def test_rewrite_without_trivial_binding(bind_to_dataflow_var):
    """rewrite_call should avoid producing trivial "y = x" bindings

    This may not be possible in all cases, and follows the same
    rules as CanonicalizeBindings.  For example, a `relax.Var` that is
    bound to a `relax.DataflowVar` may not be removed, to ensure
    that the `relax.DataflowVar` is only used within a
    `DataflowBlock`.
    """

    if bind_to_dataflow_var:

        @R.function(private=True)
        def before(x: R.Tensor((1024,))):
            with R.dataflow():
                a = R.add(x, x)
                b = R.reshape(a, (1024,))
                R.output(b)
            return b

        @R.function(private=True)
        def expected(x: R.Tensor((1024,))):
            with R.dataflow():
                b = R.add(x, x)
                R.output(b)
            return b

    else:

        @R.function(private=True)
        def before(x: R.Tensor((1024,))):
            a = R.add(x, x)
            b = R.reshape(a, (1024,))
            return b

        @R.function(private=True)
        def expected(x: R.Tensor((1024,))):
            a = R.add(x, x)
            return a

    pattern_arg = wildcard()
    pattern_shape_expr = wildcard()
    pattern = is_op("relax.reshape")(pattern_arg, pattern_shape_expr)

    def rewriter(expr, matches):
        arg = matches[pattern_arg]
        shape_expr = matches[pattern_shape_expr]

        # An identity reshape is replaced by its argument, which would
        # otherwise leave a trivial binding behind.
        if tvm.ir.structural_equal(arg.struct_info.shape, shape_expr):
            return arg
        else:
            return expr

    after = rewrite_call(pattern, rewriter, before)
    tvm.ir.assert_structural_equal(after, expected)


# Pytest parameter selecting the shape relationship between the two
# R.add operands in test_same_shape_pattern.
same_shape_func_type = tvm.testing.parameter(
    "same_static_shape",
    "same_dynamic_shape",
    "different_static_shape",
    "different_dynamic_shape",
)


def test_same_shape_pattern(same_shape_func_type):
    """`same_shape_as` constrains two patterns to equally-shaped operands.

    Matching succeeds when both add operands share an identical static
    or identical symbolic shape, and fails when the shapes differ.
    """
    if same_shape_func_type == "same_static_shape":

        @R.function(private=True)
        def func(
            a: R.Tensor((1024, 128), "float32"),
            b: R.Tensor((1024, 128), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                c = R.multiply(a, R.const(2.0))
                d = R.add(b, c)
                out = d
                R.output(out)
            return out

    elif same_shape_func_type == "same_dynamic_shape":

        @R.function(private=True)
        def func(
            a: R.Tensor(("n", 128), "float32"),
            b: R.Tensor(("n", 128), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                c = R.multiply(a, R.const(2.0))
                d = R.add(b, c)
                out = d
                R.output(out)
            return out

    elif same_shape_func_type == "different_static_shape":

        @R.function(private=True)
        def func(
            a: R.Tensor((1024, 128), "float32"),
            b: R.Tensor((1, 128), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                c = R.multiply(a, R.const(2.0))
                d = R.add(b, c)
                out = d
                R.output(out)
            return out

    elif same_shape_func_type == "different_dynamic_shape":

        @R.function(private=True)
        def func(
            a: R.Tensor(("n", 128), "float32"),
            b: R.Tensor(("m", 128), "float32"),
        ) -> R.Tensor:
            with R.dataflow():
                c = R.multiply(a, R.const(2.0))
                d = R.add(b, c)
                out = d
                R.output(out)
            return out

    else:
        raise ValueError(f"Unknown value of same_shape_func_type={same_shape_func_type}")

    with PatternContext() as ctx:
        pat_lhs = wildcard()
        pat_rhs = wildcard()
        pat_sum = is_op("relax.add")(pat_lhs, pat_rhs)
        pat_lhs.same_shape_as(pat_rhs)

    block = func.body.blocks[0]
    match = ctx.match_dfb(block)

    # Parameter names encode the expectation: "same_*" shapes match.
    if "same" in same_shape_func_type:
        assert match
    else:
        assert match is None


def test_iterative_rewrite_without_trivial_binding():
    """Avoid introducing common sub-expressions

    Pattern replacement may produce the same intermediate, which
    should appear only once in the final result.  Here, both strided
    slices of `x` rewrite to indexing into `R.split(x, 2)`, and that
    split must be emitted only once.
    """

    @R.function(private=True)
    def before(x: R.Tensor((1024,))):
        with R.dataflow():
            a = R.strided_slice(x, [0], [0], [512], [1])
            b = R.strided_slice(x, [0], [512], [1024], [1])
            c = R.add(a, b)
            R.output(c)
        return c

    @R.function(private=True)
    def expected(x: R.Tensor((1024,))):
        with R.dataflow():
            x_split = R.split(x, 2)
            a = x_split[0]
            b = x_split[1]
            c = R.add(a, b)
            R.output(c)
        return c

    pattern_arg = wildcard()
    pattern_axes = wildcard()
    pattern_begin = wildcard()
    pattern_end = wildcard()
    pattern_strides = wildcard()
    pattern = is_op("relax.strided_slice")(
        pattern_arg, pattern_axes, pattern_begin, pattern_end, pattern_strides
    )

    def rewriter(expr, matches):
        arg = matches[pattern_arg]
        axes = matches[pattern_axes]
        begin = matches[pattern_begin]
        end = matches[pattern_end]
        strides = matches[pattern_strides]
        strided_slice = matches[pattern]

        if arg.struct_info.shape is None:
            return expr

        # Only handle single-axis, unit-stride slices.
        # NOTE(review): only the axis *count* is checked; the rewrite then
        # reads shape[0], implicitly assuming axes == [0] -- TODO confirm.
        if len(axes) != 1:
            return expr

        axis = axes[0].value
        begin = begin[0].value
        end = end[0].value
        stride = strides[0].value

        if stride != 1:
            return expr

        size = arg.struct_info.shape[0]
        # Bail out unless every bound is a static integer.
        if (
            isinstance(size, tir.IntImm)
            and isinstance(begin, tir.IntImm)
            and isinstance(end, tir.IntImm)
        ):
            size = size.value
            begin = begin.value
            end = end.value
        else:
            return expr

        # If the slice covers exactly one of size//gcd equal chunks,
        # express it as indexing into an R.split of the argument.
        gcd = functools.reduce(math.gcd, [begin, end, size])
        if (end - begin) // gcd == 1:
            return rx.op.split(arg, size // gcd)[begin // gcd]

        return expr

    after = rewrite_call(pattern, rewriter, before)
    tvm.ir.assert_structural_equal(after, expected)


def test_iterative_rewrite_with_removed_intermediates():
    """Pattern replacement may require canonicalization

    A pattern may replace a tuple returned by a function with a tuple
    whose contents are known by Relax.  In that case, canonicalization
    is required to unwrap the TupleGetItem instances into the known
    contents.

    This test case shows the intermediate results produced in the
    process of pattern-matching.
    """

    @R.function(private=True)
    def before(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
        with R.dataflow():
            c = R.concat([a, b])
            d = R.split(c, 2)
            e = d[0]
            f = d[1]
            g = R.add(a, e)
            h = R.add(f, g)
            R.output(h)
        return h

    # First pattern rewrite.  The concat/rewrite can be unwrapped, so
    # `d` is rewritten from `R.split(c, 2)` into `(a, b)`.
    #
    # @R.function(private=True)
    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
    #     with R.dataflow():
    #         c = R.concat([a, b])
    #         d = (a,b)
    #         e = d[0]
    #         f = d[1]
    #         g = R.add(a, e)
    #         h = R.add(f, g)
    #         R.output(h)

    # Canonicalization step.  Because `d` is known to be `(a,b)`,
    # canonicalization can rewrite `d[0]` into `a` and `d[1]` into
    # `b`.
    #
    # @R.function(private=True)
    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
    #     with R.dataflow():
    #         c = R.concat([a, b])
    #         d = (a,b)
    #         e = a
    #         f = b
    #         g = R.add(a, a)
    #         h = R.add(b, g)
    #         R.output(h)

    # Dead-code-elimination step.  This technically isn't required
    # until the pattern matching has converged, but performing it now
    # prevents testing for matches on dead code.
    #
    # @R.function(private=True)
    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
    #     with R.dataflow():
    #         g = R.add(a, a)
    #         h = R.add(b, g)
    #         R.output(h)

    # Second pattern-matching step.  Now, the `R.add(a,a)` can match
    # the other option in our pattern, and be rewritten as
    # `R.multiply(a,R.const(2))`.
    #
    # @R.function(private=True)
    # def intermediate(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
    #     with R.dataflow():
    #         g = R.multiply(a, R.const(2))
    #         h = R.add(b, g)
    #         R.output(h)

    # Canonicalization and dead-code-elimination are applied again,
    # but have no effect this time.

    @R.function(private=True)
    def expected(a: R.Tensor((1024,)), b: R.Tensor((1024,))):
        with R.dataflow():
            g = R.multiply(a, R.const(2))
            h = R.add(b, g)
            R.output(h)
        return h

    pat_args = wildcard()

    op_concat = is_op("relax.concat")
    pat_concat = op_concat(pat_args).has_attr({"axis": 0})

    op_split = is_op("relax.split")
    pat_split = op_split(pat_concat).has_attr({"axis": 0, "indices_or_sections": T.int64(2)})

    pat_unwrap_concat_split = pat_split

    pat_arg = wildcard()
    op_add = is_op("relax.add")
    pat_add_self = op_add(pat_arg, pat_arg)

    # Either alternative of the OR-pattern may fire on a given call.
    pattern = pat_unwrap_concat_split | pat_add_self

    def rewriter(expr, matches):
        if pat_unwrap_concat_split in matches:
            args = matches[pat_args]

            # split(concat(a, b), 2) is the identity tuple (a, b) when
            # both halves have identical struct info.
            if len(args) == 2 and tvm.ir.structural_equal(args[0].struct_info, args[1].struct_info):
                return args

        elif pat_add_self in matches:
            arg = matches[pat_arg]
            return arg * rx.const(2)

        return expr

    after = rewrite_call(pattern, rewriter, before)
    tvm.ir.assert_structural_equal(expected, after)


def test_wildcard_with_struct_info_updates_when_matching():
    """A DFPattern may be restricted to a specific StructInfo"""

    # Both operands must have StructInfo compatible with R.Tensor([2, 3]).
    pat_lhs = wildcard().has_struct_info(R.Tensor([2, 3]))
    pat_rhs = wildcard().has_struct_info(R.Tensor([2, 3]))
    pat = is_op("relax.add")(pat_lhs, pat_rhs)

    def rewriter(expr, matches):
        lhs = matches[pat_lhs]
        rhs = matches[pat_rhs]
        return rx.op.multiply(lhs, rhs)

    @R.function(private=True)
    def before():
        with R.dataflow():
            A = R.zeros([2, 3], "int32")
            B = R.ones([2, 3], "int32")
            C = R.add(A, B)

            R.output(C)
        return C

    @R.function(private=True)
    def expected():
        with R.dataflow():
            A = R.zeros([2, 3], "int32")
            B = R.ones([2, 3], "int32")
            C = R.multiply(A, B)

            R.output(C)
        return C

    after = rewrite_call(pat, rewriter, before)
    tvm.ir.assert_structural_equal(expected, after)


def test_wildcard_with_struct_info_is_no_op_when_not_matching():
    """StructInfoPattern requires the StructInfo provided

    Here, the pattern would match, except that the function has
    `R.Tensor([16,32])`, and the pattern requires `R.Tensor([2,3])`.
    """

    pat_lhs = wildcard().has_struct_info(R.Tensor([2, 3]))
    pat_rhs = wildcard().has_struct_info(R.Tensor([2, 3]))
    pat = is_op("relax.add")(pat_lhs, pat_rhs)

    def rewriter(expr, matches):
        lhs = matches[pat_lhs]
        rhs = matches[pat_rhs]
        return rx.op.multiply(lhs, rhs)

    @R.function(private=True)
    def before():
        with R.dataflow():
            # This R.add has shape [16, 32], which does not satisfy the
            # pattern's [2, 3] requirement, so it is left unchanged.
            A = R.zeros([16, 32], "int32")
            B = R.ones([16, 32], "int32")
            C = R.add(A, B)

            R.output(C)
        return C

    # No match means no rewrite: the output equals the input function.
    expected = before

    after = rewrite_call(pat, rewriter, before)
    tvm.ir.assert_structural_equal(expected, after)


def test_wildcard_struct_info_for_unknown_dtype():
    """TensorStructInfo with unknown dtype allows any dtype

    `R.Tensor([2, 3])` constrains only the shape; since no dtype is
    specified, both the "int32" add and the "float32" add below should
    match and be rewritten.
    """

    pat_lhs = wildcard().has_struct_info(R.Tensor([2, 3]))
    pat_rhs = wildcard().has_struct_info(R.Tensor([2, 3]))
    pat = is_op("relax.add")(pat_lhs, pat_rhs)

    def rewriter(expr, matches):
        lhs = matches[pat_lhs]
        rhs = matches[pat_rhs]
        return rx.op.multiply(lhs, rhs)

    @R.function(private=True)
    def before():
        with R.dataflow():
            # Matches the pattern despite being "int32".
            A = R.zeros([2, 3], "int32")
            B = R.ones([2, 3], "int32")
            C = R.add(A, B)

            # Matches the pattern despite being "float32".
            D = R.zeros([2, 3], "float32")
            E = R.ones([2, 3], "float32")
            F = R.add(D, E)

            output = (C, F)
            R.output(output)
        return output

    @R.function(private=True)
    def expected():
        with R.dataflow():
            A = R.zeros([2, 3], "int32")
            B = R.ones([2, 3], "int32")
            C = R.multiply(A, B)

            D = R.zeros([2, 3], "float32")
            E = R.ones([2, 3], "float32")
            F = R.multiply(D, E)

            output = (C, F)
            R.output(output)
        return output

    after = rewrite_call(pat, rewriter, before)
    tvm.ir.assert_structural_equal(expected, after)


def test_wildcard_struct_info_with_symbolic_vars():
    """StructInfoPattern may define symbolic vars

    This test finds an elementwise `R.add`, while ignoring a
    broadcasted `R.add`.  Because the same symbolic vars `m` and `n`
    appear in both operand patterns, both operands must share the same
    shape for the pattern to match (per the `expected` result below,
    `[64, 128] + [1, 128]` is left untouched).
    """

    m = tir.Var("m", "int64")
    n = tir.Var("n", "int64")

    pat_lhs = wildcard().has_struct_info(R.Tensor([m, n]))
    pat_rhs = wildcard().has_struct_info(R.Tensor([m, n]))
    pat = is_op("relax.add")(pat_lhs, pat_rhs)

    def rewriter(expr, matches):
        lhs = matches[pat_lhs]
        rhs = matches[pat_rhs]
        return rx.op.multiply(lhs, rhs)

    @R.function(private=True)
    def before():
        with R.dataflow():
            # Elementwise add: both shapes are [64, 128], so the
            # symbolic vars bind consistently and the rewrite applies.
            A = R.zeros([64, 128], "int32")
            B = R.ones([64, 128], "int32")
            C = R.add(A, B)

            # Broadcast add: [64, 128] vs [1, 128], so no single
            # binding of (m, n) satisfies both operands.
            D = R.zeros([64, 128], "float32")
            E = R.ones([1, 128], "float32")
            F = R.add(D, E)

            output = (C, F)
            R.output(output)
        return output

    @R.function(private=True)
    def expected():
        with R.dataflow():
            A = R.zeros([64, 128], "int32")
            B = R.ones([64, 128], "int32")
            C = R.multiply(A, B)

            # The broadcasted add is unchanged.
            D = R.zeros([64, 128], "float32")
            E = R.ones([1, 128], "float32")
            F = R.add(D, E)

            output = (C, F)
            R.output(output)
        return output

    after = rewrite_call(pat, rewriter, before)
    tvm.ir.assert_structural_equal(expected, after)


def test_backtrack_if_rewriter_returns_no_op():
    """Rewriter participates in the pattern matching

    Sometimes, the pattern-matching syntax is insufficient to check if
    a replacement may be performed.  In this case, the `rewriter`
    function may perform additional validation.  If this validation
    fails, the `rewriter` function can return the original expression,
    and no replacement is performed.

    In addition, when the `rewriter` returns the original expression,
    the pattern match should backtrack to determine if another branch
    of the match may have produced a replacement.

    This functionality allows pattern replacements to be composed.
    """

    pat_match_no_rewrite = is_op("relax.add")(wildcard(), wildcard())

    # `x + zeros(...)` may be replaced by `x` alone.
    pat_arg = wildcard()
    pat_zeros = is_op("relax.zeros")(wildcard())
    pat_add = is_op("relax.add")(pat_arg, pat_zeros)

    # OR conditions are checked in the order that they occur.  Because
    # `pat_match_no_rewrite` is a superset of `pat_add`, it will
    # always match first.
    pat = pat_match_no_rewrite | pat_add

    def rewriter(expr, matches):
        if pat_match_no_rewrite in matches:
            # This branch simulates a rewrite whose precondition has
            # failed.  If the pattern-matching treats this as a
            # successful match with no replacement required, then no
            # rewrite would be performed.  On the other hand, if the
            # pattern-matching treats this as an unsuccessful match,
            # then it can backtrack and attempt `pat_add` instead.
            return expr
        elif pat_add in matches:
            return matches[pat_arg]
        else:
            raise RuntimeError("Pattern matched, but neither branch matched")

    @R.function(private=True)
    def before():
        with R.dataflow():
            A = R.ones([64, 128], "int32")
            B = R.zeros([64, 128], "int32")
            C = R.add(A, B)

            R.output(C)
        return C

    # The `pat_add` branch should apply after backtracking,
    # simplifying `ones + zeros` down to `ones`.
    @R.function(private=True)
    def expected():
        with R.dataflow():
            C = R.ones([64, 128], "int32")

            R.output(C)
        return C

    after = rewrite_call(pat, rewriter, before)
    tvm.ir.assert_structural_equal(expected, after)


def test_backtrack_for_no_op_rewriter_does_not_match_on_var():
    """The matches should always contain the bound value

    This is a regression test.  In versions from
    https://github.com/apache/tvm/pull/16732 to
    https://github.com/apache/tvm/pull/16828, the `rewrite_call`
    function could erroneously call the rewriter with `expr` and
    `matches[pat]` set to a variable (`C`) instead of the value to
    which it is bound (`R.add(A,B)`).
    """
    # Two identical patterns OR-ed together, forcing the backtracking
    # path through the second branch when the first returns a no-op.
    pat_a = is_op("relax.add")(wildcard(), wildcard())
    pat_b = is_op("relax.add")(wildcard(), wildcard())
    pat = pat_a | pat_b

    def rewriter(expr, matches):
        # The regression manifested as `matches[pat]` being a
        # `rx.Var` rather than the bound `rx.Call`.
        assert isinstance(matches[pat], rx.Call)
        return expr

    @R.function(private=True)
    def before():
        with R.dataflow():
            A = R.ones([64, 128], "int32")
            B = R.zeros([64, 128], "int32")
            C = R.add(A, B)

            R.output(C)
        return C

    # The rewriter is a deliberate no-op, so the function is unchanged.
    expected = before
    after = rewrite_call(pat, rewriter, before)
    tvm.ir.assert_structural_equal(expected, after)


# Allow running this test file directly via `python <file>`;
# `tvm.testing.main()` dispatches to pytest.
if __name__ == "__main__":
    tvm.testing.main()
